package com.hoppinzq;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.hoppinzq.embedBean.Embed;
import com.hoppinzq.embedBean.EmbeddingData;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;

import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Demo 1: recommend blogs for a search query using ChatGPT embeddings.
 * <p>
 * Before running, fill in the Redis connection settings in {@code RedisDemo}
 * and the API key in {@code Embedding}.
 */
public class demo {

    public static void main(String[] args) {
        initQuestion();
        getQuestion("java如何连接mysql？",10);
    }

    /**
     * Initialises the blog corpus (hard-coded in {@code blog.json} for this demo).
     * <p>
     * The Redis key is {@code demo:blogs}; it is populated only if it does not exist yet.
     * Redis stands in for a search-engine index (Lucene, ES). Each blog's title + content
     * is sent to the embedding API, and the resulting vector is stored under the
     * {@code "embedding"} field of the blog's JSON.
     * <p>
     * Errors are printed and swallowed, matching the rest of this demo.
     */
    private static void initQuestion(){
        RedisDemo redisDemo=new RedisDemo();
        JedisPool jedisPool=redisDemo.getJedisPool();

        try (Jedis jedis = jedisPool.getResource()) {
            // Already initialised: do not re-embed.
            if (jedis.exists("demo:blogs")) {
                return;
            }

            URL resource = Thread.currentThread().getContextClassLoader().getResource("blog.json");
            if (resource == null) {
                throw new IllegalStateException("blog.json not found on the classpath");
            }
            // Converting through URI decodes every percent-escape in the path (spaces,
            // non-ASCII directory names, ...), not just "%20".
            File textFile = new File(resource.toURI());
            String text = FileUtils.readFileToString(textFile, "utf-8");
            JSONArray blogs = JSON.parseArray(text);
            if (blogs == null || blogs.isEmpty()) {
                // Nothing to embed; an RPUSH with no values would also be rejected by Redis.
                return;
            }

            List<String> blog2Embeddings = new ArrayList<>(blogs.size());
            for (int blogIndex = 0; blogIndex < blogs.size(); blogIndex++) {
                JSONObject blog = blogs.getJSONObject(blogIndex);
                // NOTE(review): "blog__content" has a double underscore; it must match the key in blog.json — verify.
                blog2Embeddings.add(blog.get("blog_title").toString() + blog.get("blog__content").toString());
            }

            Embed embedAll = Embedding.getEmbed(blog2Embeddings);
            EmbeddingData[] embeddingAllQuestions = embedAll.getData();
            // Fail before touching Redis if the API returned fewer vectors than blogs.
            if (embeddingAllQuestions == null || embeddingAllQuestions.length < blogs.size()) {
                throw new IllegalStateException("expected " + blogs.size() + " embeddings but got "
                        + (embeddingAllQuestions == null ? 0 : embeddingAllQuestions.length));
            }

            // Assumes the embedding API returns vectors in input order — TODO confirm.
            String[] serialized = new String[blogs.size()];
            for (int i = 0; i < blogs.size(); i++) {
                JSONObject blog = blogs.getJSONObject(i);
                blog.put("embedding", embeddingAllQuestions[i].getEmbedding());
                serialized[i] = blog.toJSONString();
            }
            // A single RPUSH is one atomic Redis command, so the key never ends up half-filled
            // (which the exists() guard above would otherwise never repair).
            jedis.rpush("demo:blogs", serialized);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            redisDemo.close();
        }
    }

    /**
     * Finds and prints the blogs most similar to a question.
     * <p>
     * This simulates a big-data / cloud platform: the question's embedding is compared
     * against every stored embedding using cosine similarity, so with millions or billions
     * of entries this means millions of similarity computations (which could be optimised
     * with vectorised or parallel computation, or simply spread over several threads).
     *
     * @param question the search text to embed and compare against the stored blogs
     * @param number   maximum number of blogs to print; clamped to the range
     *                 [0, number of stored blogs]
     */
    private static void getQuestion(String question,int number){

        RedisDemo redisDemo=new RedisDemo();
        JedisPool jedisPool=redisDemo.getJedisPool();

        try (Jedis jedis = jedisPool.getResource()) {
            if (!jedis.exists("demo:blogs")) {
                System.err.println("redis中没有demo:blogs，请先调用initQuestion()初始化");
                return;
            }
            Embed embedQuestions = Embedding.getEmbed(question);
            double[] embedding = embedQuestions.getData()[0].getEmbedding();
            List<String> blogs = jedis.lrange("demo:blogs", 0, -1);
            List<JSONObject> chatSimilarity = new ArrayList<>(blogs.size());
            for (String blogJson : blogs) {
                JSONObject jsonObject = JSONObject.parseObject(blogJson);
                Object storedEmbedding = jsonObject.get("embedding");
                if (!(storedEmbedding instanceof JSONArray)) {
                    // Malformed entry without a vector: skip it instead of aborting the whole search.
                    continue;
                }
                // The vector is replaced by the similarity score, so the printed result stays
                // small and the sort below can read the score from the same field.
                jsonObject.put("embedding", cosineSimilarity(embedding, getDoubleArray((JSONArray) storedEmbedding)));
                chatSimilarity.add(jsonObject);
            }

            // Sort by similarity, highest first. Double.compare is a total order, which
            // Collections.sort requires (the manual >/< comparison treated NaN as equal to everything).
            Collections.sort(chatSimilarity, new Comparator<JSONObject>() {
                @Override
                public int compare(JSONObject o1, JSONObject o2) {
                    return Double.compare(o2.getDouble("embedding"), o1.getDouble("embedding"));
                }
            });

            // Never request more results than exist (subList would throw) and never a negative count.
            int shown = Math.max(0, Math.min(number, chatSimilarity.size()));
            System.err.println("问题："+question);
            System.err.println("下面是推荐的"+shown+"个博客：");
            System.err.println(formatJson(JSONArray.toJSONString(chatSimilarity.subList(0, shown))));
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            redisDemo.close();
        }
    }

    /**
     * Computes the cosine similarity of two vectors.
     *
     * @param vector1 first vector
     * @param vector2 second vector; must have the same length as {@code vector1}
     *                (commons-math throws otherwise)
     * @return dot(v1, v2) / (|v1| * |v2|), or 0.0 when either vector has zero norm
     *         (the formula would otherwise yield NaN)
     */
    private static double cosineSimilarity(double[] vector1, double[] vector2) {
        // copyArray=false: wrap the arrays directly; they are only read here.
        RealVector realVector1 = new ArrayRealVector(vector1, false);
        RealVector realVector2 = new ArrayRealVector(vector2, false);
        double norm = realVector1.getNorm() * realVector2.getNorm();
        if (norm == 0.0) {
            return 0.0;
        }
        return realVector1.dotProduct(realVector2) / norm;
    }

    /**
     * Converts a JSON array of numbers to a {@code double[]}.
     *
     * @param jsonArray the array to convert
     * @return the values in order; an element whose conversion throws a
     *         {@link JSONException} becomes 0.0
     */
    private static double[] getDoubleArray(JSONArray jsonArray){
        double[] array = new double[jsonArray.size()];
        for (int i = 0; i < jsonArray.size(); i++) {
            try {
                array[i] = jsonArray.getDouble(i);
            } catch (JSONException e) {
                e.printStackTrace();
                array[i] = 0.0;
            }
        }
        return array;
    }

    /**
     * Pretty-prints a compact JSON string with tab indentation.
     * <p>
     * Characters inside string literals (including escaped quotes) are copied verbatim,
     * so commas or brackets in blog text are not split across lines.
     *
     * @param jsonStr compact JSON text
     * @return the indented JSON, or "" for null/empty input
     */
    private static String formatJson(String jsonStr) {
        if (null == jsonStr || jsonStr.isEmpty()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        int level = 0;
        boolean inString = false;
        boolean escaped = false;
        for (int i = 0; i < jsonStr.length(); i++) {
            char current = jsonStr.charAt(i);
            if (inString) {
                sb.append(current);
                if (escaped) {
                    escaped = false;
                } else if (current == '\\') {
                    escaped = true;
                } else if (current == '"') {
                    inString = false;
                }
                continue;
            }
            switch (current) {
                case '"':
                    inString = true;
                    sb.append(current);
                    break;
                case '{':
                case '[':
                    sb.append(current);
                    sb.append('\n');
                    level++;
                    addIndentBlank(sb, level);
                    break;
                case '}':
                case ']':
                    sb.append('\n');
                    level--;
                    addIndentBlank(sb, level);
                    sb.append(current);
                    break;
                case ',':
                    sb.append(current);
                    sb.append('\n');
                    addIndentBlank(sb, level);
                    break;
                default:
                    sb.append(current);
            }
        }
        return sb.toString();
    }

    /**
     * Appends {@code level} tab characters to the builder.
     *
     * @param sb    builder to append to
     * @param level number of tabs to append (no-op when &lt;= 0)
     */
    private static void addIndentBlank(StringBuilder sb, int level) {
        for (int i = 0; i < level; i++) {
            sb.append('\t');
        }
    }
}
